﻿// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections;
using System.ComponentModel;
using System.Drawing;
using System.Formats.Nrbf;
using System.Private.Windows.BinaryFormat;
using System.Reflection;
using System.Reflection.Metadata;
using System.Runtime.ExceptionServices;
using System.Runtime.Serialization;
using System.Text.Json;

namespace System.Private.Windows.Nrbf;

internal static class SerializationRecordExtensions
{
    /// <inheritdoc cref="DecodeNrbf(Stream, out IReadOnlyDictionary{SerializationRecordId, SerializationRecord})"/>
    internal static SerializationRecord DecodeNrbf(this Stream stream)
    {
        // Callers of this overload only want the root record, so the id-to-record map is discarded.
        return DecodeNrbf(stream, out _);
    }

    /// <summary>
    ///  Decodes an NRBF stream, rethrowing the decoding failures listed in the exception filter
    ///  as <see cref="SerializationException"/>.
    /// </summary>
    /// <param name="stream">The stream to decode. It is passed to the decoder with <c>leaveOpen: true</c>.</param>
    /// <param name="recordMap">Receives the map of record ids to records produced by the decoder.</param>
    /// <returns>The record returned by <see cref="NrbfDecoder.Decode(Stream, out IReadOnlyDictionary{SerializationRecordId, SerializationRecord}, PayloadOptions?, bool)"/>.</returns>
    internal static SerializationRecord DecodeNrbf(this Stream stream, out IReadOnlyDictionary<SerializationRecordId, SerializationRecord> recordMap)
    {
        try
        {
            SerializationRecord rootRecord = NrbfDecoder.Decode(stream, out recordMap, leaveOpen: true);
            return rootRecord;
        }
        catch (Exception ex) when (IsDecodingFailure(ex))
        {
            // Make the exception easier to catch, but retain the original stack trace.
            throw ex.ConvertToSerializationException();
        }

        // Only these exception types are converted; anything else propagates unchanged.
        static bool IsDecodingFailure(Exception exception) => exception
            is ArgumentException
            or InvalidCastException
            or ArithmeticException
            or IOException
            or TargetInvocationException;
    }

    /// <summary>
    ///  Deserializes the object graph identified by <paramref name="rootRecord"/> to an object.
    /// </summary>
    /// <param name="rootRecord">The record whose id is used as the root of the deserialization.</param>
    /// <param name="recordMap">The map of record ids to records that <paramref name="rootRecord"/> belongs to.</param>
    /// <param name="typeResolver">The resolver supplied to the deserializer through its options.</param>
    /// <returns>The deserialized object.</returns>
    /// <exception cref="NotSupportedException">
    ///  The deserializer threw a <see cref="SerializationException"/>. The original exception is kept as the
    ///  inner exception and its stack trace is applied to the new exception.
    /// </exception>
    [RequiresUnreferencedCode("Ultimately calls resolver for type names in the data.")]
    public static object? Deserialize(
        this SerializationRecord rootRecord,
        IReadOnlyDictionary<SerializationRecordId, SerializationRecord> recordMap,
        ITypeResolver typeResolver)
    {
        var options = new DeserializationOptions { TypeResolver = typeResolver };

        try
        {
            return Deserializer.Deserialize(rootRecord.Id, recordMap, options);
        }
        catch (SerializationException ex)
        {
            // Surface the failure as NotSupportedException while keeping the message, the original
            // exception, and the original stack trace available to the caller.
            NotSupportedException wrapper = new(ex.Message, ex);
            string originalStackTrace = ex.StackTrace ?? string.Empty;
            throw ExceptionDispatchInfo.SetRemoteStackTrace(wrapper, originalStackTrace);
        }
    }

    /// <summary>
    ///  Signature of the record readers that are run through <see cref="TryGet"/>: given a record, produce
    ///  a value and report whether the record could be read.
    /// </summary>
    internal delegate bool TryGetDelegate(SerializationRecord record, [NotNullWhen(true)] out object? value);

    /// <summary>
    ///  Invokes <paramref name="get"/> on <paramref name="record"/>, turning a
    ///  <see cref="KeyNotFoundException"/> or <see cref="InvalidCastException"/> into a failed attempt.
    /// </summary>
    /// <param name="get">The reader to run.</param>
    /// <param name="record">The record handed to <paramref name="get"/>.</param>
    /// <param name="value">The value produced by <paramref name="get"/>, or <see langword="null"/> if it threw one of the handled exceptions.</param>
    /// <returns>The result of <paramref name="get"/>, or <see langword="false"/> if it threw one of the handled exceptions.</returns>
    internal static bool TryGet(TryGetDelegate get, SerializationRecord record, [NotNullWhen(true)] out object? value)
    {
        try
        {
            return get(record, out value);
        }
        catch (Exception ex) when (IsCorruptDataFailure(ex))
        {
            // This should only really happen with corrupted data.
            Debug.Fail(ex.Message);
            value = null;
            return false;
        }

        // Any other exception type is not handled here and propagates to the caller.
        static bool IsCorruptDataFailure(Exception exception) =>
            exception is KeyNotFoundException || exception is InvalidCastException;
    }

    /// <summary>
    ///  Attempts to read this record as a <see cref="Point"/>.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="value">The boxed <see cref="Point"/> on success; otherwise <see langword="null"/>.</param>
    /// <returns>
    ///  <see langword="true"/> if the record is a class record whose type name matches <see cref="Point"/>
    ///  and whose "x" and "y" members could be read.
    /// </returns>
    public static bool TryGetPoint(this SerializationRecord record, [NotNullWhen(true)] out object? value)
    {
        return TryGet(ReadPoint, record, out value);

        static bool ReadPoint(SerializationRecord candidate, [NotNullWhen(true)] out object? point)
        {
            if (candidate is ClassRecord fields
                && fields.TypeNameMatches(typeof(Point))
                && fields.HasMember("x")
                && fields.HasMember("y"))
            {
                int x = fields.GetInt32("x");
                int y = fields.GetInt32("y");
                point = new Point(x, y);
                return true;
            }

            point = null;
            return false;
        }
    }

    /// <summary>
    ///  Attempts to read this record as a <see cref="Size"/>.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="value">The boxed <see cref="Size"/> on success; otherwise <see langword="null"/>.</param>
    /// <returns>
    ///  <see langword="true"/> if the record is a class record whose type name matches <see cref="Size"/>
    ///  and whose "width" and "height" members could be read.
    /// </returns>
    public static bool TryGetSize(this SerializationRecord record, [NotNullWhen(true)] out object? value)
    {
        return TryGet(ReadSize, record, out value);

        static bool ReadSize(SerializationRecord candidate, [NotNullWhen(true)] out object? size)
        {
            size = null;

            if (candidate is not ClassRecord fields || !fields.TypeNameMatches(typeof(Size)))
            {
                return false;
            }

            if (!fields.HasMember("height") || !fields.HasMember("width"))
            {
                return false;
            }

            int width = fields.GetInt32("width");
            int height = fields.GetInt32("height");
            size = new Size(width, height);
            return true;
        }
    }

    /// <summary>
    ///  Attempts to read this record as a <see cref="Rectangle"/>.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="value">The boxed <see cref="Rectangle"/> on success; otherwise <see langword="null"/>.</param>
    /// <returns>
    ///  <see langword="true"/> if the record is a class record whose type name matches <see cref="Rectangle"/>
    ///  and whose "x", "y", "width" and "height" members could be read.
    /// </returns>
    public static bool TryGetRectangle(this SerializationRecord record, [NotNullWhen(true)] out object? value)
    {
        return TryGet(ReadRectangle, record, out value);

        static bool ReadRectangle(SerializationRecord candidate, [NotNullWhen(true)] out object? rectangle)
        {
            if (candidate is ClassRecord fields
                && fields.TypeNameMatches(typeof(Rectangle))
                && fields.HasMember("x")
                && fields.HasMember("y")
                && fields.HasMember("width")
                && fields.HasMember("height"))
            {
                int x = fields.GetInt32("x");
                int y = fields.GetInt32("y");
                int width = fields.GetInt32("width");
                int height = fields.GetInt32("height");
                rectangle = new Rectangle(x, y, width, height);
                return true;
            }

            rectangle = null;
            return false;
        }
    }

    /// <summary>
    ///  Attempts to read this record as a <see cref="PointF"/>.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="value">The boxed <see cref="PointF"/> on success; otherwise <see langword="null"/>.</param>
    /// <returns>
    ///  <see langword="true"/> if the record is a class record whose type name matches <see cref="PointF"/>
    ///  and whose "x" and "y" members could be read.
    /// </returns>
    public static bool TryGetPointF(this SerializationRecord record, [NotNullWhen(true)] out object? value)
    {
        return TryGet(ReadPointF, record, out value);

        static bool ReadPointF(SerializationRecord candidate, [NotNullWhen(true)] out object? point)
        {
            point = null;

            if (candidate is not ClassRecord fields || !fields.TypeNameMatches(typeof(PointF)))
            {
                return false;
            }

            if (!fields.HasMember("x") || !fields.HasMember("y"))
            {
                return false;
            }

            float x = fields.GetSingle("x");
            float y = fields.GetSingle("y");
            point = new PointF(x, y);
            return true;
        }
    }

    /// <summary>
    ///  Attempts to read this record as a <see cref="SizeF"/>.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="value">The boxed <see cref="SizeF"/> on success; otherwise <see langword="null"/>.</param>
    /// <returns>
    ///  <see langword="true"/> if the record is a class record whose type name matches <see cref="SizeF"/>
    ///  and whose "width" and "height" members could be read.
    /// </returns>
    public static bool TryGetSizeF(this SerializationRecord record, [NotNullWhen(true)] out object? value)
    {
        return TryGet(ReadSizeF, record, out value);

        static bool ReadSizeF(SerializationRecord candidate, [NotNullWhen(true)] out object? size)
        {
            if (candidate is ClassRecord fields
                && fields.TypeNameMatches(typeof(SizeF))
                && fields.HasMember("height")
                && fields.HasMember("width"))
            {
                float width = fields.GetSingle("width");
                float height = fields.GetSingle("height");
                size = new SizeF(width, height);
                return true;
            }

            size = null;
            return false;
        }
    }

    /// <summary>
    ///  Attempts to read this record as a <see cref="RectangleF"/>.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="value">The boxed <see cref="RectangleF"/> on success; otherwise <see langword="null"/>.</param>
    /// <returns>
    ///  <see langword="true"/> if the record is a class record whose type name matches <see cref="RectangleF"/>
    ///  and whose "x", "y", "width" and "height" members could be read.
    /// </returns>
    public static bool TryGetRectangleF(this SerializationRecord record, [NotNullWhen(true)] out object? value)
    {
        return TryGet(ReadRectangleF, record, out value);

        static bool ReadRectangleF(SerializationRecord candidate, [NotNullWhen(true)] out object? rectangle)
        {
            rectangle = null;

            if (candidate is not ClassRecord fields || !fields.TypeNameMatches(typeof(RectangleF)))
            {
                return false;
            }

            bool hasLocation = fields.HasMember("x") && fields.HasMember("y");
            if (!hasLocation || !fields.HasMember("width") || !fields.HasMember("height"))
            {
                return false;
            }

            float x = fields.GetSingle("x");
            float y = fields.GetSingle("y");
            float width = fields.GetSingle("width");
            float height = fields.GetSingle("height");
            rectangle = new RectangleF(x, y, width, height);
            return true;
        }
    }

    /// <summary>
    ///  Attempts to read this record as a <see cref="Color"/>.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="value">The boxed <see cref="Color"/> on success; otherwise <see langword="null"/>.</param>
    /// <returns>
    ///  <see langword="true"/> if the record is a class record whose type name matches <see cref="Color"/>,
    ///  it has the "name", "value", "knownColor" and "state" members, and the non-public
    ///  <see cref="Color"/> constructor used to rebuild it was found.
    /// </returns>
    public static bool TryGetColor(this SerializationRecord record, [NotNullWhen(true)] out object? value)
    {
        return TryGet(ReadColor, record, out value);

        static bool ReadColor(SerializationRecord candidate, [NotNullWhen(true)] out object? color)
        {
            color = null;

            if (candidate is not ClassRecord fields || !fields.TypeNameMatches(typeof(Color)))
            {
                return false;
            }

            if (!fields.HasMember("name")
                || !fields.HasMember("value")
                || !fields.HasMember("knownColor")
                || !fields.HasMember("state"))
            {
                return false;
            }

            // The constructor that accepts the serialized fields is looked up as a non-public instance member.
            Type[] signature = [typeof(long), typeof(short), typeof(string), typeof(KnownColor)];
            ConstructorInfo? constructor = typeof(Color).GetConstructor(
                BindingFlags.Instance | BindingFlags.NonPublic,
                signature);

            if (constructor is null)
            {
                return false;
            }

            // Members are read in constructor-parameter order: value, state, name, knownColor.
            long rawValue = fields.GetInt64("value");
            short state = fields.GetInt16("state");
            string? name = fields.GetString("name");
            KnownColor knownColor = (KnownColor)Convert.ToInt32(fields.GetInt16("knownColor"));

            color = constructor.Invoke([rawValue, state, name, knownColor]);
            return true;
        }
    }

    /// <summary>
    ///  Attempts to read this record as a primitive type or string.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="value">The primitive or string value on success; otherwise <see langword="null"/>.</param>
    /// <returns><see langword="true"/> if this represented a primitive type or string.</returns>
    public static bool TryGetPrimitiveType(this SerializationRecord record, [NotNullWhen(true)] out object? value)
    {
        return TryGet(ReadPrimitive, record, out value);

        static bool ReadPrimitive(SerializationRecord candidate, [NotNullWhen(true)] out object? primitive)
        {
            switch (candidate.RecordType)
            {
                case SerializationRecordType.BinaryObjectString:
                    primitive = ((PrimitiveTypeRecord<string>)candidate).Value;
                    return true;

                case SerializationRecordType.MemberPrimitiveTyped:
                    primitive = GetMemberPrimitiveTypedValue(candidate);
                    return true;

                default:
                    primitive = null;
                    return false;
            }
        }
    }

    /// <summary>
    ///  Gets the value of a record that is expected to be a
    ///  <see cref="SerializationRecordType.MemberPrimitiveTyped"/> record.
    /// </summary>
    /// <param name="record">The record to read. The expected record type is only asserted in debug builds.</param>
    /// <returns>The value of the record after casting it to <see cref="PrimitiveTypeRecord"/>.</returns>
    internal static object GetMemberPrimitiveTypedValue(this SerializationRecord record)
    {
        Debug.Assert(record.RecordType is SerializationRecordType.MemberPrimitiveTyped);

        PrimitiveTypeRecord primitiveRecord = (PrimitiveTypeRecord)record;
        return primitiveRecord.Value;
    }

    /// <summary>
    ///  Tries to get this object as a <see cref="List{T}"/> of <see cref="PrimitiveType"/>.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="list">The reconstructed list on success; otherwise <see langword="null"/>.</param>
    /// <returns>
    ///  <see langword="true"/> if the record is a constructed generic list with a single type argument whose
    ///  "_items" member is a primitive or string array and whose "_size" member is a valid element count
    ///  for that array.
    /// </returns>
    public static bool TryGetPrimitiveList(this SerializationRecord record, [NotNullWhen(true)] out object? list)
    {
        return TryGet(Get, record, out list);

        static bool Get(SerializationRecord record, [NotNullWhen(true)] out object? list)
        {
            list = null;

            if (record is not ClassRecord classInfo
                || !classInfo.HasMember("_items")
                || !classInfo.HasMember("_size")
                || classInfo.GetRawValue("_size") is not int size
                // The size comes from the (untrusted) data: a negative count can never be valid.
                || size < 0
                || !classInfo.TypeName.IsConstructedGenericType
                || classInfo.TypeName.GetGenericTypeDefinition().Name != typeof(List<>).Name
                || classInfo.TypeName.GetGenericArguments().Length != 1
                || classInfo.GetRawValue("_items") is not ArrayRecord arrayRecord
                || !IsPrimitiveArrayRecord(arrayRecord)
                // Mirrors the check in TryGetPrimitiveArrayList: corrupt data claiming more elements than the
                // backing array holds is rejected here rather than left to fail while trimming, as such a
                // failure would not necessarily be one of the exceptions TryGet handles. The records accepted
                // by IsPrimitiveArrayRecord are single dimensional, so the first length is the element count.
                || size > arrayRecord.Lengths[0])
            {
                return false;
            }

            // BinaryFormatter serializes the entire backing array, so we need to trim it down to the size of the list.
            list = arrayRecord switch
            {
                SZArrayRecord<string> ar => ar.GetArray().CreateTrimmedList(size),
                SZArrayRecord<bool> ar => ar.GetArray().CreateTrimmedList(size),
                SZArrayRecord<byte> ar => ar.GetArray().CreateTrimmedList(size),
                SZArrayRecord<sbyte> ar => ar.GetArray().CreateTrimmedList(size),
                SZArrayRecord<char> ar => ar.GetArray().CreateTrimmedList(size),
                SZArrayRecord<short> ar => ar.GetArray().CreateTrimmedList(size),
                SZArrayRecord<ushort> ar => ar.GetArray().CreateTrimmedList(size),
                SZArrayRecord<int> ar => ar.GetArray().CreateTrimmedList(size),
                SZArrayRecord<uint> ar => ar.GetArray().CreateTrimmedList(size),
                SZArrayRecord<long> ar => ar.GetArray().CreateTrimmedList(size),
                SZArrayRecord<ulong> ar => ar.GetArray().CreateTrimmedList(size),
                SZArrayRecord<float> ar => ar.GetArray().CreateTrimmedList(size),
                SZArrayRecord<double> ar => ar.GetArray().CreateTrimmedList(size),
                SZArrayRecord<decimal> ar => ar.GetArray().CreateTrimmedList(size),
                SZArrayRecord<TimeSpan> ar => ar.GetArray().CreateTrimmedList(size),
                SZArrayRecord<DateTime> ar => ar.GetArray().CreateTrimmedList(size),
                // IsPrimitiveArrayRecord passed but the element type is not one handled above.
                _ => throw new InvalidOperationException()
            };

            return true;
        }
    }

    /// <summary>
    ///  Tries to get this object as a <see cref="ArrayList"/> of <see cref="PrimitiveType"/> values.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="value">The reconstructed <see cref="ArrayList"/> on success; otherwise <see langword="null"/>.</param>
    /// <returns>
    ///  <see langword="true"/> if the record is an <see cref="ArrayList"/> class record whose "_size" member is a
    ///  valid element count for its "_items" array and whose first "_size" items are all
    ///  <see langword="null"/> or primitive records.
    /// </returns>
    public static bool TryGetPrimitiveArrayList(this SerializationRecord record, [NotNullWhen(true)] out object? value)
    {
        return TryGet(Get, record, out value);

        static bool Get(SerializationRecord record, [NotNullWhen(true)] out object? value)
        {
            value = null;

            if (record is not ClassRecord classInfo
                || !classInfo.TypeNameMatches(typeof(ArrayList))
                || !classInfo.HasMember("_items")
                || !classInfo.HasMember("_size")
                || classInfo.GetRawValue("_size") is not int size
                || classInfo.GetRawValue("_items") is not SZArrayRecord<SerializationRecord> arrayRecord
                // The size comes from the (untrusted) data. A negative value would make the ArrayList
                // constructor below throw ArgumentOutOfRangeException, which TryGet does not handle.
                || size < 0
                || size > arrayRecord.Length)
            {
                return false;
            }

            ArrayList arrayList = new(size);
            SerializationRecord?[] array = arrayRecord.GetArray();

            // Only the first "size" entries are copied; anything after that in the array is ignored.
            for (int i = 0; i < size; i++)
            {
                SerializationRecord? elementRecord = array[i];
                if (elementRecord is null)
                {
                    arrayList.Add(null);
                }
                else if (elementRecord is PrimitiveTypeRecord primitiveTypeRecord)
                {
                    arrayList.Add(primitiveTypeRecord.Value);
                }
                else
                {
                    // It was a complex type (represented as a ClassRecord or an ArrayRecord)
                    return false;
                }
            }

            value = arrayList;
            return true;
        }
    }

    /// <summary>
    ///  Attempts to read this record as an <see cref="Array"/> of primitive types.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="value">The array on success; otherwise <see langword="null"/>.</param>
    /// <returns><see langword="true"/> if the record is a primitive or string array record.</returns>
    public static bool TryGetPrimitiveArray(this SerializationRecord record, [NotNullWhen(true)] out object? value)
    {
        return TryGet(ReadArray, record, out value);

        static bool ReadArray(SerializationRecord candidate, [NotNullWhen(true)] out object? array)
        {
            if (IsPrimitiveArrayRecord(candidate))
            {
                array = candidate switch
                {
                    SZArrayRecord<string> strings => strings.GetArray(),
                    SZArrayRecord<bool> bools => bools.GetArray(),
                    SZArrayRecord<byte> bytes => bytes.GetArray(),
                    SZArrayRecord<sbyte> sbytes => sbytes.GetArray(),
                    SZArrayRecord<char> chars => chars.GetArray(),
                    SZArrayRecord<short> shorts => shorts.GetArray(),
                    SZArrayRecord<ushort> ushorts => ushorts.GetArray(),
                    SZArrayRecord<int> ints => ints.GetArray(),
                    SZArrayRecord<uint> uints => uints.GetArray(),
                    SZArrayRecord<long> longs => longs.GetArray(),
                    SZArrayRecord<ulong> ulongs => ulongs.GetArray(),
                    SZArrayRecord<float> floats => floats.GetArray(),
                    SZArrayRecord<double> doubles => doubles.GetArray(),
                    SZArrayRecord<decimal> decimals => decimals.GetArray(),
                    SZArrayRecord<TimeSpan> timeSpans => timeSpans.GetArray(),
                    SZArrayRecord<DateTime> dateTimes => dateTimes.GetArray(),
                    // The record type passed the check but the element type is not one handled above.
                    _ => throw new InvalidOperationException()
                };

                return array is not null;
            }

            array = null;
            return false;
        }
    }

    /// <summary>
    ///  Attempts to read this record as a binary formatted <see cref="Hashtable"/> of
    ///  <see cref="PrimitiveType"/> keys and values, returning it strongly typed.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="hashtable">The reconstructed <see cref="Hashtable"/> on success; otherwise <see langword="null"/>.</param>
    /// <returns>The result of the <see cref="object"/>-typed overload.</returns>
    public static bool TryGetPrimitiveHashtable(this SerializationRecord record, [NotNullWhen(true)] out Hashtable? hashtable)
    {
        // The object-typed overload does the work; its result is only narrowed here.
        bool found = TryGetPrimitiveHashtable(record, out object? untyped);
        hashtable = (Hashtable?)untyped;
        return found;
    }

    /// <summary>
    ///  Attempts to read this record as a binary formatted <see cref="Hashtable"/> of
    ///  <see cref="PrimitiveType"/> keys and values.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="hashtable">The reconstructed <see cref="Hashtable"/> on success; otherwise <see langword="null"/>.</param>
    /// <returns>
    ///  <see langword="true"/> if the record is a <see cref="Hashtable"/> without a comparer whose keys are all
    ///  primitive records and whose values are all <see langword="null"/> or primitive records.
    /// </returns>
    public static bool TryGetPrimitiveHashtable(this SerializationRecord record, [NotNullWhen(true)] out object? hashtable)
    {
        return TryGet(ReadHashtable, record, out hashtable);

        static bool ReadHashtable(SerializationRecord candidate, [NotNullWhen(true)] out object? result)
        {
            result = null;

            if (candidate.RecordType != SerializationRecordType.SystemClassWithMembersAndTypes
                || candidate is not ClassRecord fields
                || !fields.TypeNameMatches(typeof(Hashtable)))
            {
                return false;
            }

            if (!fields.HasMember("Keys") || !fields.HasMember("Values"))
            {
                return false;
            }

            // Note that hashtables with custom comparers and/or hash code providers will have non null Comparer
            if (fields.GetSerializationRecord("Comparer") is not null)
            {
                return false;
            }

            if (fields.GetSerializationRecord("Keys") is not SZArrayRecord<SerializationRecord?> keysRecord
                || fields.GetSerializationRecord("Values") is not SZArrayRecord<SerializationRecord?> valuesRecord
                || keysRecord.Length != valuesRecord.Length)
            {
                return false;
            }

            Hashtable table = new(keysRecord.Length);
            SerializationRecord?[] keys = keysRecord.GetArray();
            SerializationRecord?[] values = valuesRecord.GetArray();

            for (int index = 0; index < keys.Length; index++)
            {
                // A null key fails this pattern as well, so no separate null check is needed.
                if (keys[index] is not PrimitiveTypeRecord primitiveKey)
                {
                    return false;
                }

                switch (values[index])
                {
                    case null:
                        // null values are allowed
                        table[primitiveKey.Value] = null;
                        break;

                    case PrimitiveTypeRecord primitiveValue:
                        table[primitiveKey.Value] = primitiveValue.Value;
                        break;

                    default:
                        // It was a complex type (represented as a ClassRecord or an ArrayRecord)
                        return false;
                }
            }

            result = table;
            return true;
        }
    }

    /// <summary>
    ///  Tries to get this object as a binary formatted <see cref="NotSupportedException"/>.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="exception">
    ///  A new <see cref="NotSupportedException"/> carrying the serialized "Message" on success;
    ///  otherwise <see langword="null"/>.
    /// </param>
    /// <returns>
    ///  <see langword="true"/> if the record is a class record whose type name matches
    ///  <see cref="NotSupportedException"/> and that has a "Message" member.
    /// </returns>
    public static bool TryGetNotSupportedException(
        this SerializationRecord record,
        [NotNullWhen(true)] out object? exception)
    {
        return TryGet(Get, record, out exception);

        static bool Get(SerializationRecord record, [NotNullWhen(true)] out object? exception)
        {
            exception = null;

            if (record is not ClassRecord classInfo
                || !classInfo.TypeNameMatches(typeof(NotSupportedException))
                // Check for the member up front, like the other readers in this class do, so that a record
                // without it is simply rejected instead of going through the exception path in TryGet.
                || !classInfo.HasMember("Message"))
            {
                return false;
            }

            // Only the message is carried over; the rest of the serialized exception state is not read.
            exception = new NotSupportedException(classInfo.GetString("Message"));
            return true;
        }
    }

    /// <summary>
    ///  Attempts to read a supported .NET type object (not WinForms) that has no <see cref="TypeConverter"/>.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="value">The value produced by the first reader that succeeds; otherwise the value left by the last reader tried.</param>
    /// <returns><see langword="true"/> if any of the readers below succeeded.</returns>
    public static bool TryGetFrameworkObject(
        this SerializationRecord record,
        [NotNullWhen(true)] out object? value)
    {
        // The readers are tried in this fixed order and the first success wins.
        if (record.TryGetPrimitiveType(out value))
        {
            return true;
        }

        if (record.TryGetPrimitiveList(out value))
        {
            return true;
        }

        if (record.TryGetPrimitiveArray(out value))
        {
            return true;
        }

        if (record.TryGetPrimitiveArrayList(out value))
        {
            return true;
        }

        if (record.TryGetPrimitiveHashtable(out value))
        {
            return true;
        }

        if (record.TryGetRectangleF(out value))
        {
            return true;
        }

        if (record.TryGetPointF(out value))
        {
            return true;
        }

        return record.TryGetNotSupportedException(out value);
    }

    /// <summary>
    ///  Attempts to read a supported System.Drawing.Primitives object that has a <see cref="TypeConverter"/>.
    ///  This method is used for Clipboard payload deserialization. ResX deserialization uses
    ///  <see cref="TypeConverter"/>s for these types.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="value">The value produced by the first reader that succeeds; otherwise the value left by the last reader tried.</param>
    /// <returns><see langword="true"/> if any of the readers below succeeded.</returns>
    public static bool TryGetDrawingPrimitivesObject(
        this SerializationRecord record,
        [NotNullWhen(true)] out object? value)
    {
        // The readers are tried in this fixed order and the first success wins.
        if (record.TryGetPoint(out value))
        {
            return true;
        }

        if (record.TryGetSize(out value))
        {
            return true;
        }

        if (record.TryGetSizeF(out value))
        {
            return true;
        }

        if (record.TryGetRectangle(out value))
        {
            return true;
        }

        return record.TryGetColor(out value);
    }

    /// <summary>
    ///  Reports whether the record's type is <see cref="SerializationRecordType.ArraySingleString"/> or
    ///  <see cref="SerializationRecordType.ArraySinglePrimitive"/>.
    /// </summary>
    private static bool IsPrimitiveArrayRecord(SerializationRecord serializationRecord)
    {
        SerializationRecordType recordType = serializationRecord.RecordType;

        return recordType == SerializationRecordType.ArraySingleString
            || recordType == SerializationRecordType.ArraySinglePrimitive;
    }

    /// <summary>
    ///  Tries to deserialize this object if it was serialized as JSON.
    /// </summary>
    /// <param name="record">The record to inspect.</param>
    /// <param name="resolver">The resolver used to bind the type name stored in the JSON wrapper.</param>
    /// <param name="object">
    ///  The deserialized object when it was JSON data of a matching type; otherwise <see langword="default"/>.
    /// </param>
    /// <returns>
    ///  <c>isJsonData</c> is <see langword="true"/> if the record was written with the custom JSON assembly name.
    ///  <c>isValidType</c> is <see langword="true"/> if a non-null object of the requested type was produced.
    /// </returns>
    /// <exception cref="SerializationException">
    ///  If the data was supposed to be our <see cref="JsonData{T}"/>, but was serialized incorrectly.
    /// </exception>
    /// <remarks>
    ///  <para>
    ///   Exceptions thrown while looking up the type or while JSON deserializing are not caught here;
    ///   they propagate to the caller unchanged.
    ///  </para>
    /// </remarks>
    /// <devdoc>
    ///  We don't try to resolve all types in the graph of a type when deserializing JSON as the default options
    ///  should not present the same risks as BinaryFormatter deserialization. JSON binding is just for the root
    ///  type.
    /// </devdoc>
    internal static (bool isJsonData, bool isValidType) TryGetObjectFromJson<T>(
        this SerializationRecord record,
        ITypeResolver resolver,
        out T? @object)
    {
        @object = default;

        // JSON wrapped data is recognized by the custom assembly name on the record's type name.
        bool writtenAsJson = record.TypeName.AssemblyName?.FullName == IJsonData.CustomAssemblyName;
        if (!writtenAsJson)
        {
            // The data was not serialized as JSON.
            return (isJsonData: false, isValidType: false);
        }

        if (record is not ClassRecord jsonRecord
            || jsonRecord.GetRawValue("<JsonBytes>k__BackingField") is not SZArrayRecord<byte> jsonBytes
            || jsonRecord.GetRawValue("<InnerTypeAssemblyQualifiedName>k__BackingField") is not string innerTypeName
            || !TypeName.TryParse(innerTypeName, out TypeName? parsedTypeName))
        {
            // This is supposed to be JsonData, but somehow the binary formatted data is corrupt.
            throw new SerializationException(SR.ClipboardOrDragDrop_JsonDeserializationFailed);
        }

        Type? targetType = resolver.BindToType(parsedTypeName);
        bool matchesRequestedType = targetType.IsAssignableTo(typeof(T));
        if (!matchesRequestedType)
        {
            // Not the type the caller asked for.
            return (isJsonData: true, isValidType: false);
        }

        if (targetType == typeof(object))
        {
            // Special case for deserializing to object. JsonSerializer.Deserialize<object> gives back a JsonElement.
            // We want to do this for the untyped APIs as downlevel apps would be able to read the object via
            // GetData (via the IObjectReference behavior in BinaryFormatter). Preventing it here would be difficult
            // to explain. Doing this also facilitates moving to the typed APIs for existing data that is JSON
            // serializable. You can simply switch to SerializeAsJson and know existing consumers will not be broken.
            //
            // The assembly qualified name is tried first (without throwing); failing that, fall back to just
            // the full name, this time letting a lookup failure throw.
            targetType = Type.GetType(innerTypeName, throwOnError: false)
                ?? Type.GetType(parsedTypeName.FullName, throwOnError: true);
        }

        // Let the original exception bubble up if deserialization fails.
        if (targetType is not null)
        {
            Utf8JsonReader jsonReader = new(jsonBytes.GetArray());
            object? deserialized = JsonSerializer.Deserialize(ref jsonReader, targetType);
            @object = (T?)deserialized;
        }

        return (isJsonData: true, isValidType: @object is not null);
    }
}
